{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, numpy, json, os\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from netdissect.progress import verbose_progress, default_progress\n",
    "from netdissect.nethook import edit_layers\n",
    "from netdissect.modelconfig import create_instrumented_model\n",
    "from netdissect.zdataset import standard_z_sample, z_sample_for_model\n",
    "from netdissect.easydict import EasyDict\n",
    "from netdissect.aceoptimize import ace_loss\n",
    "from netdissect.segmenter import UnifiedParsingSegmenter\n",
    "from netdissect.fullablate import measure_full_ablation\n",
    "from netdissect.plotutil import plot_tensor_images, plot_max_heatmap\n",
    "import netdissect.aceoptimize\n",
    "import netdissect.fullablate\n",
    "\n",
    "# Switch netdissect's progress helper into verbose mode for the cells below.\n",
    "verbose_progress(True)\n",
    "\n",
    "# Dissection output directory and the layer under study.\n",
    "dissectdir = 'dissect/churchoutdoor'\n",
    "layer = 'layer4'\n",
    "dissect_json = os.path.join(dissectdir, 'dissect.json')\n",
    "with open(dissect_json) as dissect_file:\n",
    "    dissection = EasyDict(json.load(dissect_file))\n",
    "\n",
    "# Build the segmenter, then the model described by the dissection settings.\n",
    "segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')\n",
    "model = create_instrumented_model(dissection.settings)\n",
    "\n",
    "# Index of the (classname, 'object') entry in the segmenter's label list.\n",
    "classname = 'tree'\n",
    "label_names = segmenter.get_label_and_category_names()[0]\n",
    "classnum = label_names.index((classname, 'object'))\n",
    "\n",
    "edit_layers(model, [layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "unit_count = 20\n",
    "\n",
    "# Sample 1000 latent vectors and iterate over them in batches of 10.\n",
    "big_sample = z_sample_for_model(model, 1000, seed=3)\n",
    "big_dataset = TensorDataset(big_sample)\n",
    "big_loader = DataLoader(big_dataset, batch_size=10)\n",
    "\n",
    "# Unit ordering from the dissection's '<classname>-iou' ranking of this layer.\n",
    "# The stored score is negated, so iou_order lists units from the largest\n",
    "# iou_scores value to the smallest and iou_values holds them in that order.\n",
    "lrec = [layer_rec for layer_rec in dissection.layers if layer_rec.layer == layer][0]\n",
    "rrec = [ranking for ranking in lrec.rankings if ranking.name == classname + '-iou'][0]\n",
    "iou_scores = -torch.tensor(rrec.score)\n",
    "iou_values, iou_order = (-iou_scores).sort(0)\n",
    "iou_values = -iou_values\n",
    "# 0/1 mask over units: 1 for the first unit_count units of iou_order.\n",
    "iou_ablation = torch.zeros_like(iou_scores)\n",
    "iou_ablation[iou_order[:unit_count]] = 1\n",
    "\n",
    "# Load the learned ablation from the snapshot saved for this class.\n",
    "snapdir = os.path.join(dissectdir, layer, 'ace_reg0.005', classname, 'snapshots')\n",
    "data = torch.load(os.path.join(snapdir, 'epoch-9.pth'))\n",
    "learned_scores = data['ablation'][0,:,0,0]\n",
    "# Order units from the largest learned score down; the tiny iou term only\n",
    "# decides the order between units whose learned scores are equal.\n",
    "_, learned_order = (-learned_scores - iou_scores.cuda() * 1e-5).sort(0)\n",
    "learned_values = learned_scores[learned_order]\n",
    "# Learned score for the first unit_count units of learned_order, 0 elsewhere.\n",
    "learned_ablation = torch.zeros_like(learned_scores)\n",
    "learned_ablation[learned_order[:unit_count]] = learned_scores[learned_order[:unit_count]]\n",
    "\n",
    "progress = default_progress()\n",
    "\n",
    "\n",
    "def total_ace_loss(ablation, discrete_units, **ace_options):\n",
    "    '''Sum ace_loss over every batch of big_loader with gradients disabled.\n",
    "\n",
    "    ablation: 1-d per-unit tensor; it is reshaped to (1, units, 1, 1) and\n",
    "        moved to the GPU before being handed to ace_loss.\n",
    "    discrete_units: forwarded unchanged as ace_loss's discrete_units option.\n",
    "    ace_options: extra keyword options for ace_loss (e.g. mixed_units=True).\n",
    "\n",
    "    The high_replacement argument of ace_loss is always all zeros. Uses the\n",
    "    notebook globals segmenter, classnum, model, layer, big_loader, progress.\n",
    "    Returns the running sum of the per-batch ace_loss values.\n",
    "    '''\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for [small_sample] in progress(big_loader):\n",
    "            total += ace_loss(segmenter, classnum, model, layer,\n",
    "                    torch.zeros_like(ablation)[None,:,None,None].cuda(),  # high_replacement\n",
    "                    ablation[None,:,None,None].cuda(),  # ablation\n",
    "                    small_sample, 0, 0, 0, run_backward=False,\n",
    "                    discrete_pixels=True,\n",
    "                    discrete_units=discrete_units,\n",
    "                    ablation_only=True,\n",
    "                    fullimage_measurement=True,\n",
    "                    fullimage_ablation=True,\n",
    "                    **ace_options)\n",
    "    return total\n",
    "\n",
    "\n",
    "# (1) baseline: ace_loss with an all-zero ablation and discrete_units=0\n",
    "baseline_loss = total_ace_loss(torch.zeros_like(learned_scores), 0)\n",
    "\n",
    "# (2) iou ablation restricted to unit_count units\n",
    "iou_loss = total_ace_loss(iou_scores, unit_count)\n",
    "\n",
    "# (3) learned ablation restricted to unit_count units (mixed_units enabled)\n",
    "learned_loss = total_ace_loss(learned_scores, unit_count, mixed_units=True)\n",
    "\n",
    "print('iou', 1 - iou_loss / baseline_loss, 'learned', 1 - learned_loss / baseline_loss)\n",
    "# Cross-check: these ratios are meant to be compared with the ones derived\n",
    "# from measure_full_ablation for the same unit_count."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# (4) call measure_full_ablation on the iou order to get the same ratio\n",
    "iou_units = iou_order[:unit_count]\n",
    "iou_unit_values = torch.ones_like(iou_values[:unit_count])\n",
    "iou_measurements = measure_full_ablation(\n",
    "    segmenter, big_loader, model, classnum, layer, iou_units, iou_unit_values)\n",
    "\n",
    "# (5) and on the learned ablation to get the same ratio\n",
    "learned_units = learned_order[:unit_count]\n",
    "learned_unit_values = learned_values[:unit_count]\n",
    "learned_measurements = measure_full_ablation(\n",
    "    segmenter, big_loader, model, classnum, layer, learned_units, learned_unit_values)\n",
    "\n",
    "print(iou_measurements)\n",
    "print(learned_measurements)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each measurement expressed as 1 - value / first value, for both orderings.\n",
    "iou_reduction = 1 - iou_measurements / iou_measurements[0]\n",
    "learned_reduction = 1 - learned_measurements / learned_measurements[0]\n",
    "print(iou_reduction, learned_reduction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot 1 - measurement / measurement[0] for both unit orderings on one axis.\n",
    "# The legend and axis labels make the two curves distinguishable.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot((1 - iou_measurements / iou_measurements[0]).numpy(), label='iou')\n",
    "ax.plot((1 - learned_measurements / learned_measurements[0]).numpy(), label='learned')\n",
    "ax.set_xlabel('measurement index')\n",
    "ax.set_ylabel('1 - measurement / measurement[0]')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the measurements obtained for the learned unit ordering.\n",
    "learned_measurements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the learned measurements next to the summed ace_loss values scaled by 16 * 16.\n",
    "(learned_measurements,\n",
    " baseline_loss * 16 * 16,\n",
    " learned_loss * 16 * 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First 10 entries of the first sampled latent vector.\n",
    "big_sample[0][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First 10 entries of the first row returned by standard_z_sample(3, 512, seed=2).\n",
    "standard_z_sample(3, 512, seed=2)[0][:10]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}